Source: LucidJS/src/optvis/renderer.js

import {jitter, fixedScale, compose, standardTransforms} from './transform.js';
import {naiveFromImage, imgLaplacianPyramid,
  randLaplacianPyramid} from './param/image.js';
import {channel, deepdream, neuron, spatial, output,
activationModification, style} from './objectives.js';
import {inverseDecorrelate} from './param/color.js';

import * as tf from '@tensorflow/tfjs';

export const objectiveTypes = {
  LAYER: 'layer',
  CHANNEL: 'channel',
  CLASS: 'class',
  NEURON: 'neuron',
  SPATIAL: 'spatial',
  ACT_ADJUST: 'act. adjust',
  STYLE: 'style',
}

export const loadStates = {
  INITIAL: 'initial',
  LOADING: 'loading',
  LOADED: 'loaded',
  OPTIMIZING: 'optimizing',
}

/**
 * Encapsulates state and methods for various feature visualization techniques.
 */
export class LucidRenderer {
  constructor(){
    this.inputParams = {
      inputSize: 128,
      decorrelate: true,
      pyramidLayers: 4,
      baseImage: null,
    };

    this.objectiveParams = {
      type: objectiveTypes.LAYER,
      layer: '',
      featureMapLayer: '',
      channel: 0,
      neuronX: 0,
      neuronY: 0,
      classInd: 0,
      negative: false,

      pyrLayerWeights: [1, 1, 1, 1],

      jitter: 5,
      learningRate: 0.05,

      activationOverpaint: null,
      activationModifications: {},

      styleImage: null,
      styleLayers: {
        content: [],
        style: [],
      },
      contentImage: null,
    };

    this.featureMapAuxModel = null;
    this.layerOutput = null;
    this.featureMapLayerOutput = null;
    this.isOptimizing = false;

    this.activationShape = null;

    this.iterations = 0;
    this.optimizer = tf.train.adam(this.objectiveParams.learningRate);
    this.ctr = 0;
    this.optimCallback = () => {};
    this.stopOptimCb = () => {};

    tf.setBackend('webgl', true);
  }

  /**
   * Assigns model to the renderer.
   */
  setModel = (model) => {
    this.model = model;
  }

  /**
   * Disposes renderer.
   */
  dispose = () => {
    this.featureMapAuxModel = null;
    this.layerOutput = null;
    this.featureMapLayerOutput = null;
    this.isOptimizing = false;

    this.activationShape = null;

    this.iterations = 0;
    this.optimizer = null;
    this.ctr = 0;
    this.optimCallback = () => {};
  }

  /**
   * initObjectiveParamsForModel - Initializes parameters according to model to provide
   * some fitting default parameters for convenience:
   * - layer: last Conv layer
   * - neuron: central neuron for initial layer
   *
   */
  initObjectiveParamsForModel = () => {
    let firstConvLayer = null;
    for (let layer of this.model.layers) {
      if('kernelSize' in layer) {
        firstConvLayer = layer;
        break;
      }
    }
    if(!firstConvLayer) {
      console.log("Didn't initialize target layer because no Conv2D layer \
      has been found!");
    } else {
      const [x, y] = this.getCentralNeuronCoords(firstConvLayer);
      this.objectiveParams.neuronX = x;
      this.objectiveParams.neuronY = y;

      this.objectiveParams.layer = firstConvLayer.name;
      this.objectiveParams.featureMapLayer = firstConvLayer.name;
    }
  }

  /**
   * setInputParams - Sets input params and causes re-compilation of model,
   * not possible during running optimization.
   *
   * @param  {*} inputParams inputParams object containing all necessary
   * properties
   * @return {*}
   */
  setInputParams = (inputParams) => {
    if(this.isOptimizing){
      throw "Can't change input params during optimization!";
    }
    for (const [key, value] of Object.entries(this.inputParams)) {
      if (!(key in inputParams)) {
        throw "Invalid input params, " + key + "is missing!\n" + inputParams;
      }
    }

    this.inputParams = inputParams;

    this.resizeStyleImage();

    this.compileInput();
  }

  /**
   * Generates input parametrization and initializes regularization transform chain.
   */
  compileInput = () => {
    const w = this.inputParams.inputSize;
    const h = this.inputParams.inputSize;
    const ch = this.model.input.shape[3];
    const pyrL = this.inputParams.pyramidLayers;
    const decorrelate = this.inputParams.decorrelate;

    if(this.inputParams.baseImage){
      //throw "not implemented in renderer yet! needs pyramid parametrisation.";
      const [imgF, trainable] = imgLaplacianPyramid(
        this.inputParams.baseImage, w, h, ch, undefined, decorrelate, pyrL);
        this.paramF = imgF;
        this.trainable = trainable;
    } else {
      const [pyramidF, trainable] = randLaplacianPyramid(w, h,
        ch, 1, 0.01, decorrelate, pyrL);
      this.paramF = pyramidF;
      this.trainable = trainable;
    }

    this.initTransformF();
  }

  /**
   * Stars optimization loop.
   * @param {*} iterations number of iterations to execute
   * @param {*} optimCallback callback to be run after each optimization step
   */
  startOptimization = (iterations=1000, optimCallback=() => {}) => {
    this.iterations=iterations;
    this.optimizer = tf.train.adam(this.objectiveParams.learningRate);
    this.ctr = 0;
    this.optimCallback = optimCallback;

    this.compileObjective();
    this.optimize();
  }

  /**
   * Stops optimization loop.
   * @param {*} cb callback to be run after last optimization step before stopping.
   */
  stopOptimization = (cb) => {
    this.iterations = 0;
    this.ctr = 0;
    if(cb) {
      this.stopOptimCb = () => {
        cb();
        this.stopOptimCb = () => {};
      }
    }
  }

  /**
   * Check if renderer has optimization target layer set.
   */
  canOptimize = () => {
    return this.layer !== ''
  }

  /**
   * Do one optimization step
   */
  optimize = () => {
    this.isOptimizing = true;

    tf.tidy( () =>{

      const negLoss = this.optimizer.minimize(this.lossF, true, this.trainable);

      if (this.ctr++ < this.iterations) {
        this.optimCallback(false);
        requestAnimationFrame(()=>{
          this.optimize()});
      } else {
        this.isOptimizing = false;
        this.stopOptimization();
        this.optimCallback(true);
        this.stopOptimCb();
      }
    });
  }

  /**
   * Returns class prediction of current input image.
   */
  getClassPrediction = () => {
    const weights = this.objectiveParams.pyrLayerWeights;
    return tf.tidy(() => {
      const prediction = this.model.apply(
        this.fixedSizeTransformF(
          this.paramF(this.trainable, weights)), {training: true});

      let classProbs;
      if (Array.isArray(prediction)){
        classProbs = prediction[0].reshape([-1]);
      } else {
        classProbs = prediction.reshape([-1]);
      }

      const topClass = tf.argMax(classProbs);
      return topClass;
    });
  }

  /**
   * Returns number of channels of current layer.
   */
  getChannelNumber = () => {
    if(this.layerOutput) {
      const [b, w, h, ch] = this.layerOutput.shape;
      return ch;
    } else {
      return 0;
    }
  }

  /**
   * Returns activation tensors for currently selected layer.
   */
  getActivationMaps = () => {
    if(this.featureMapAuxModel) {
      return tf.tidy( () =>{
        const featureMaps = this.featureMapAuxModel.apply(
          this.paramF(this.trainable, this.objectiveParams.pyrLayerWeights), {training: false});
        this.activationShape = featureMaps.shape;

        let s = featureMaps.shape;
        let featureMapUint8 = featureMaps.transpose([3, 1, 2, 0])
        .reshape([s[3], s[1], s[2], 1]);

        featureMapUint8 = this.deprocessFeatureMap(featureMapUint8);

        return featureMapUint8;
      });
    } else {
      return null;
    }
  }

  /**
   * Returns activation tensor shape for currently selected layer.
   */
  getActivationShape = () => {
    if(this.featureMapAuxModel) {
      if(!this.activationShape){
        this.getActivationMaps();
      }
      return this.activationShape;
    } else {
      return null;
    }
  }

  /**
   * Returns feature map tensor in 0 - 255 color range.
   */
  deprocessFeatureMap(featureMap) {
    return tf.tidy(() => {
      const {mean, variance} = tf.moments(featureMap);
      let imgToUint8_fm = featureMap.sub(0);
      imgToUint8_fm = imgToUint8_fm.div(tf.sqrt(variance).add(.00001).mul(2));
      return imgToUint8_fm;
    });
  }

  /**
   * Sets the optimization objective type.
   */
  setObjectiveType = (type) => {
    this.objectiveParams.type = type;
  }

  /**
   * Prepares optimization for current objective.
   */
  compileObjective = () => {
    this.compileLossF();

    this.setLayer(this.objectiveParams.layer);
    this.setFeatureMapLayer(this.objectiveParams.featureMapLayer);
  }

  /**
   * Prepares loss function for current objective.
   */
  compileLossF = () => {
    let transformF = this.transformF;
    if(this.objectiveParams.type === objectiveTypes.CLASS){
      transformF = this.fixedSizeTransformF;
    }
    this.lossF = () => {
      return tf.tidy(() => {
        const objF = this.getObjectiveF(this.objectiveParams.type);
        const ret = objF(transformF(this.paramF(
          this.trainable, this.objectiveParams.pyrLayerWeights)));
        return ret;
      });
    };
  }

  /**
   * Returns objective function depending on specified type.
   * @param {string} type optimization type
   */
  getObjectiveF = (type) => {
    const options = {
      layer: this.objectiveParams.layer,
      channel: this.objectiveParams.channel,
      neuron: [
        this.objectiveParams.neuronX,
        this.objectiveParams.neuronY
      ],
      out: this.objectiveParams.classInd,
      neg: this.objectiveParams.negative,
    }
    if(type === objectiveTypes.LAYER){
      return deepdream(this.model, options);
    } else if (type === objectiveTypes.CHANNEL){
      return channel(this.model, options);
    } else if (type === objectiveTypes.NEURON){
      return neuron(this.model, options);
    }  else if (type === objectiveTypes.SPATIAL){
      return spatial(this.model, options);
    } else if (type === objectiveTypes.CLASS){
      return output(this.model, options);
    } else if(type === objectiveTypes.ACT_ADJUST) {
      return activationModification(
        this.model, this.objectiveParams.contentImage,
        this.objectiveParams.activationModifications);
    } else if(type === objectiveTypes.STYLE) {
      const cLrs = this.objectiveParams.styleLayers.content;
      const sLrs = this.objectiveParams.styleLayers.style;
      return style(
        this.model, this.objectiveParams.contentImage,
        this.objectiveParams.styleImage, cLrs, sLrs);
    }
  }



  /**
   * Sets layer for optimization objective. Can't change during
   * optimization. Resets target neuron to central neuron.
   *
   * @param  {type} layer layer to optimize for
   * @return {type}
   */
  setLayer = (layer) => {
    if(this.isOptimizing){
      throw "Can't change target layer during optimization!";
    }
    const changed = this.objectiveParams.layer !== layer;
    const outLayer = this.model.getLayer(layer);
    if(outLayer.outputShape.length !== 4) {
      throw "Can only select layers with 2D ouput!";
      return;
    }
    this.objectiveParams.layer = layer;
    this.layerOutput = outLayer.output;

    const [b, w, h, ch] = this.model.getLayer(layer).outputShape;

    //only layers with 2D output
    if(w && h && changed) {
      const [x, y] = this.getCentralNeuronCoords(outLayer);
      this.objectiveParams.neuronX = x;
      this.objectiveParams.neuronY = y;
    }
  }

  /**
   * Sets layer to output featuremaps for.
   *
   * @param  {type} layer layer to output featuremaps for.
   * @return {type}
   */
  setFeatureMapLayer = (layer) => {
    this.objectiveParams.featureMapLayer = layer;
    const outLayer = this.model.getLayer(layer);
    this.featureMapLayerOutput = outLayer.output;

    this.featureMapAuxModel = tf.model(
      {inputs: this.model.inputs, outputs: this.featureMapLayerOutput});
  }

  /**
   * Sets target channel. Can be changed interactively during
   * optimization.
   *
   * @param  {type} channel target channel
   * @return {type}
   */
  setChannel = (channel) => {
    if(channel < 0){
      throw "Channel index must be zero or positive!";
    }
    this.objectiveParams.channel = channel;
    this.compileLossF();
  }

  /**
   * Sets target neuron. Can be changed interactively during
   * optimization.
   *
   * @param  {type} x x coordinate of neuron in featuremap
   * @param  {type} y y coordinate of neuron in featuremap
   * @return {type}
   */
  setNeuron = (x, y) => {
    if(x < 0 || y < 0) {
      throw "Neuron indices must be zero or positive!";
    }
    this.objectiveParams.neuronX = x;
    this.objectiveParams.neuronY = y;
    this.compileLossF();
  }


  /**
   * Get the 2D coordinates of the central neuron for
   * a specific layer.
   *
   * @param  {type} layer layer to get neuron coordinates for.
   * @return {array}       [x, y]
   */
  getCentralNeuronCoords = (layer) => {
    if(!('kernelSize' in layer)){
      throw "Can't get central neuron coordinates for non-Conv2D layer!";
    }
    const [b, w, h, ch] = this.model.getLayer(layer.name).outputShape;
    const [bI, wI, hI, chI] = this.model.input.shape;
    const poolRatio = wI / w;
    const pooledW = this.inputParams.inputSize / poolRatio;
    return [Math.floor(pooledW/2), Math.floor(pooledW/2)];
  }

  /**
   * Set style image for style transfer.
   */
  setStyleImage = (styleImg) => {
    if(this.objectiveParams.styleImage) {
      this.objectiveParams.styleImage.dispose();
    }
    if(styleImg){
      const w = this.inputParams.inputSize;
      const [f, trainable] = naiveFromImage(
        styleImg, w, w, 3, 1, true);
      const frozenData = f(trainable).dataSync();
      const frozenT = tf.tensor(frozenData, [1, w, w, 3]);

      this.objectiveParams.styleImage = frozenT;
    } else {
      this.objectiveParams.styleImage = null;
    }
    this.compileLossF();
  }

  /**
   * Set content image for style transfer.
   */
  setContentImage = (contentImg) => {
    if(this.objectiveParams.contentImage) {
      this.objectiveParams.contentImage.dispose();
    }
    if(contentImg){
      const data = contentImg.data;
      const ch = data.length / (contentImg.width * contentImg.height);
      const imgShape = [1, contentImg.height, contentImg.width, ch];
      const contentImageT = tf.tensor(contentImg.data,
       imgShape, 'float32');

      this.objectiveParams.contentImage = contentImageT;
    } else {
      this.objectiveParams.contentImage = null;
    }
    this.compileLossF();
  }

  /**
   * Resizes style image to fit input image dimensions.
   */
  resizeStyleImage = () => {
    if(this.objectiveParams.styleImage) {
      const styleImageT = tf.tidy(() => {
        const w = this.inputParams.inputSize;
        const h = this.inputParams.inputSize;

        let refImg = tf.image.resizeBilinear(
          this.objectiveParams.styleImage, [w, h]);
        refImg = refImg.slice([0, 0, 0, 0], [1, h, w, 3]);

        const styleImageData = refImg.dataSync();
        const styleImT = tf.tensor(styleImageData,
         [1, h, w, 3], 'float32');
        return styleImT;
      })

      this.objectiveParams.styleImage.dispose();
      this.objectiveParams.styleImage = styleImageT;
    }
  }

  /**
   * Sets layers that should be considered for style loss.
   */
  setStyleLayers = (styleLayers) => {
    this.objectiveParams.styleLayers = styleLayers;
    this.compileLossF();
  }

  /**
   * Sets target class. Can be changed interactively during
   * optimization.
   *
   * @param  {type} classInd target class index
   */
  setClass = (classInd) => {
    if(classInd < 0){
      throw "Class index must be zero or positive!";
    }
    this.objectiveParams.classInd = classInd;
    this.compileLossF();
  }


  /**
   * Set negative optimization objective status
   *
   * @param  {type} negative negative optimization enabled
   */
  setNegative = (negative) => {
    this.objectiveParams.negative = negative;
    this.compileLossF();
  }


  /**
   * Builds input transform chain
   *
   */
  initTransformF = () => {
    let transforms = [jitter(this.objectiveParams.jitter)];
    let fixedSizeTransforms = [jitter(this.objectiveParams.jitter)];

    const [b, w, h, ch] = this.model.input.shape;
    fixedSizeTransforms.push(fixedScale([w, h]));

    this.fixedSizeTransformF = compose(fixedSizeTransforms);

    if(this.objectiveParams.type === objectiveTypes.CLASS) {
      this.transformF = compose(fixedSizeTransforms);
    } else {
      this.transformF = compose(transforms);
    }
  }


  /**
   * Set input jitter
   *
   * @param  {type} jitter amount of jitter (defaults to 5)
   */
  setJitter = (jitter) => {
    this.objectiveParams.jitter = jitter;
    this.initTransformF();
    this.compileLossF();
  }

  /**
   * Set optimizer learning rate
   *
   * @param  {type} learningRate learning rate
   */
  setLearningRate = (learningRate) => {
    this.objectiveParams.learningRate = learningRate;
    if(this.optimizer){
      this.optimizer.learningRate = learningRate;
    }
  }
}